{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the A2C Agent\n",
    "This notebook visualizes the A2C agent playing the Pacman game to confirm that the model is working. It is not for visualizing the imagination-augmented agent.\n",
    "\n",
    "First start off by importing the necessary modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "from a2c import get_actor_critic\n",
    "import env_model\n",
    "from common.minipacman import MiniPacman\n",
    "import tensorflow as tf\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, create our environment. We don't want multiprocessing here, so we use a single environment (hence `nenvs=1`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "nenvs = 1\n",
    "nsteps=5\n",
    "\n",
    "done = False\n",
    "env = MiniPacman('regular', 1000)\n",
    "ob_space = env.observation_space.shape\n",
    "nw, nh, nc = ob_space\n",
    "ac_space = env.action_space\n",
    "\n",
    "states = env.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper function to display an image to the Jupyter Notebook so we can see the game being played in our browser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import time \n",
    "\n",
    "def displayImage(image, step, reward):\n",
    "    clear_output(True)\n",
    "    s = \"step: \" + str(step) + \" reward: \" + str(reward)\n",
    "    plt.figure(figsize=(10,3))\n",
    "    plt.title(s)\n",
    "    plt.imshow(image)\n",
    "    plt.show()\n",
    "    time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load in the saved weights and see the game being played! Replace the weights I saved with whatever you want to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPQAAADSCAYAAAB92u3VAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAD8BJREFUeJzt3XvQXHV9x/H3hyTcIhIgECEBHqQMLVIntBnQFgHlIqA0\nWIvCoBKlo84YvI4a7bTQGdqh9UIvOigKQgVlKJWRoQhE5CLVAkkI14Ck4ZKEkITaUlBuKd/+cX6P\nbLbPk2fPZfc8+8vnNbPznD3n7Nnv+e1+93fO2d9+H0UEZpaHbdoOwMya44Q2y4gT2iwjTmizjDih\nzTLihDbLiBPaJg1JI5JC0tS2YxlWTuhxSDpH0mUDeq7tJF0k6XFJz0paLumEjuWjb/TnOm5/3rH8\ns5LuT499VNJnBxF32yQtlLRE0ouSLhlj+XskrUjt8qCkkzuWZdlm/iScHKYCq4EjgSeAE4ErJf1u\nRDzWsd6MiNg0xuMFfAC4F9gfuFHS6oi4YqInljR1nG32VUPP+yRwLvB2YIeu7c8GLgPmA9dTtOk/\nSxqJiA3UaLNJLSK26hvweWAt8CzwMHA0cDzwEvAy8BxwT1p3Z+AiYF16zLnAlLRsAfBvwNeAZ4CH\ngKNrxHUv8O40PQIEMLXHx/4D8I/jLBvd1pkUHx63pflvAn4G/DdwD3BUmv9W4L6Oxy8G7uq4/1Pg\n5DS9CPiP1JYPAu/qWG+0fc4H/nO07YAvA08Dq4CPldnPjm2fC1zSNe8wYEPXvI3Am8u22TDdWg+g\n1Z2HAyl6xr3S/RFg/zR9DnBZ1/pXA98EpgN7AHcCH0nLFgCbgE8B04D3psTeNS1fBFzbY1yzgBeA\n3+6IK9KHyBrgO8DMcR4r4G7go+MsH93WP6X92AGYnZLsRIrTsGPT/d3T8heAmWm/1qc4dkrLngd2\nS9s+BdgrbeO9wK+APbva5yyKI5IdgI9SfPDtDewK3NyZ0L222TgJPQW4FTgpTZ+c2m562TYbplvr\nAbS68/BbwAbgGGBa17LNEjol2YvADh3zTgNuTtMLKA4B1bH8TuD9JWOaBvwY+GbHvNcA81IizAKu\nAm4Y5/F/SdHDbjfO8tGEfn3HvM8D3+1a7wbgjDT9U+CPKXrxG4ErKY5i3grcu4V9WQ7M72ifJ7qW\n/6QziYDjaKiHTvPPpDjC2gT8GnhHlTYbpttWfQ4dESslfZIied8g6Qbg0xHx5Bir70uRbOskjc7b\nhqKHH7U20jskeZyix+qJpG2A71Ic7i/siPM5YEm6u17SwhTHThHxbMfjF1KcF74lIl6c4Ok6494X\nOEXSSR3zplH0mFD0dEdR9HC3Av9Fcb7/Yro/+vwfAD5N8aEBxQfRzHGeE4q26Zz3+AQx90zSMcDf\npriXAb8PXCPphIhY3rFemTab9Lb6q9wR8b2IOJziTR3A34wu6lp1NcUbeGZEzEi310bEGzrWma2O\nbAf2oei1J5QedxFFD/zuiHh5S2Gnv795/SR9iOIQ9eiIWNPDU3bu32qKHnpGx216RJyXlo8m9BFp\n+laKhD4yTSNpX+BbFB9Eu0XEDOB+isPZsZ4TimsRe3fc36eHuHs1l+L6wJKIeCUi7gLuoDgaI8Vc\nts0mva06oSUdKOltkrajOE98HnglLV4PjKRek4hYR3G4+RVJr5W0jaT9JR3Zsck9gI9LmibpFOB3\ngOt6DOeCtP5JEfF8V5yHpVi3kbQbxQWcWyLimbT8dOCvgWMjYlX5luAy4CRJb5c0RdL2ko6SNCct\n/xnF9YZDgTsj4gGKD8DDgNvSOtMpEnZjiumDwMETPO+VFO01R9IuFMnVM0lTJW1PcY48GvfoUedd\nwOGS5qZ1DwHeQnGxsYk2m5zaPuZv8wa8keI891ngl8C1vHqBbDfgdorDy2Vp3s4UibeG4oLX3cCp\nadkCNr/K/QvguI7n+iLwo3HiGD06eI
HinG/0dnpafhrwKMVFpnUUF7Re1/H4R3n1ivzo7RvjPNcI\nY5ynUiTnrakdNgL/CuzTsfznpOsF6f5VwIqubfxVevzTwFfT9v60o31u71p/Kq9e9X6UrqvcW2qz\ntPyctH7n7ZyO5QuBlen1XQV8pkqbDdNNaeesJkkLKN68h7cdi229tupDbrPcOKHNMuJDbrOMuIc2\ny4gT2iwjAx0pNnPmzBgZGRnkU5plYenSpU9HxO4TrTfQhB4ZGWHJkiUTr2hmm5HU07DYWofcko6X\n9LCklZJKjfIxs+ZVTmhJU4CvAycABwGnSTqoqcDMrLw6PfShwMqIWBURLwFXUFSHMLOW1Eno2Wz+\n07c1ad5mJH041X1asnHjxhpPZ2YT6fvXVhFxYUTMi4h5u+8+4UU6M6uhTkKvZfPfss5J88ysJXUS\n+i7gAEn7SdoWOBW4ppmwzKyKyt9DR8SmVL7lBoofmF8cxQ/fzawltQaWRMR19F6Ro5TNK/m0o9cf\nrkyCUA0o8zujYXp/leGx3GYZcUKbZcQJbZYRJ7RZRpzQZhlxQptlxAltlhEntFlGnNBmGXFCm2Uk\ni38nW2YIXdtD/lwGvby2R2kO0/vLPbRZRurUFNtb0s2SHpT0gKRPNBmYmZVX55B7E8W/51wmaSdg\nqaTFEfFgQ7GZWUmVe+iIWBcRy9L0s8AKxqgpZmaD08g5tKQR4BDgjia2Z2bV1E5oSa8B/gX4ZET8\nzxjLXfXTbEDq/ueMaRTJfHlE/GCsdVz102xw6lzlFnARsCIivtpcSGZWVZ0e+g+B9wNvk7Q83U5s\nKC4zq6BO1c/bgfYrrZnZb2Qx9LNfo+36MYyv7WGMOSvVtmWG4A7Ri+ahn2YZcUKbZcQJbZYRJ7RZ\nRpzQZhlxQptlxAltlhEntFlGnNBmGXFCm2Uki6Gf5f7Rd+8r97rdMiMDXfWzvF7bt9z7oEQApYaJ\nlli3D9xDm2WkiYolUyTdLenaJgIys+qa6KE/QVEg0MxaVrcE0RzgHcC3mwnHzOqo20P/HfA54JUG\nYjGzmurUFHsnsCEilk6wnqt+mg1I3ZpifyTpMeAKitpil3Wv5KqfZoNT5z9nfCEi5kTECHAq8JOI\neF9jkZlZaf4e2iwjjYwUi4hbgFua2JaZVZfF0M9+6UexxyEqIDl03LY+5DbLihPaLCNOaLOMOKHN\nMuKENsuIE9osI05os4w4oc0y4oQ2y4gT2iwjW93Qz7arbrpCaMHt0B/uoc0yUrem2AxJV0l6SNIK\nSW9uKjAzK6/uIfffA9dHxJ9I2hbYsYGYzKyiygktaWfgCGABQES8BLzUTFhmVkWdQ+79gI3Ad1Kh\n/W9Lmt69kosEmg1OnYSeCvwecEFEHAL8CljUvZKLBJoNTp2EXgOsiYg70v2rKBLczFpSp+rnU8Bq\nSQemWUcDDzYSlZlVUvcq91nA5ekK9yrgg/VDMrOqaiV0RCwH5jUUi5nVlMXQT1d7zJtf39556KdZ\nRpzQZhlxQptlxAltlhEntFlGnNBmGXFCm2XECW2WESe0WUayGCk2TEXkPOqpPL++vXMPbZYRJ7RZ\nRupW/fyUpAck3S/p+5K2byowMyuvckJLmg18HJgXEQcDU4BTmwrMzMqre8g9FdhB0lSKEr5P1g/J\nzKqqU4JoLfBl4AlgHfBMRNzYvZ6rfpoNTp1D7l2A+RTlfPcCpkt6X/d6rvppNjh1DrmPAR6NiI0R\n8TLwA+APmgnLzKqok9BPAG+StKMkUVT9XNFMWGZWRZ1z6DsoanEvA+5L27qwobjMrIK6VT/PBs5u\nKJaBUNtj8+h9HGProU4SbbfDMA099Ugxs4w4oc0y4oQ2y4gT2iwjTmizjDihzTLihDbLiBPaLCNO\naLOMOKHNMpJF1c9+iR7H/JUZmjhMwwgni17bt0zbtj2ctF/cQ5tlZMKElnSxpA2S7u+Yt6ukxZIe\nSX
936W+YZtaLXnroS4Dju+YtAm6KiAOAm9J9M2vZhAkdEbcBv+yaPR+4NE1fCpzccFxmVkHVc+hZ\nEbEuTT8FzGooHjOrofZFsSguBY97fdFVP80Gp2pCr5e0J0D6u2G8FV3102xwqib0NcAZafoM4IfN\nhGNmdfTytdX3gZ8DB0paI+lM4DzgWEmPUJTzPa+/YZpZLyYcKRYRp42z6OiGYzGzmrIY+tmvSp79\n2OzQDTksM1S15X3rV9sO02vmoZ9mGXFCm2XECW2WESe0WUac0GYZcUKbZcQJbZYRJ7RZRpzQZhlx\nQptlZBIP/ex9zGG/qj32ul1X/Uz6tG99qfpZ4vlLjX5teZioe2izjFSt+vklSQ9JulfS1ZJm9DdM\nM+tF1aqfi4GDI+KNwC+ALzQcl5lVUKnqZ0TcGBGb0t1/B+b0ITYzK6mJc+gPAT9qYDtmVlOthJb0\nZ8Am4PItrOOqn2YDUjmhJS0A3gmcHlv4r26u+mk2OJW+h5Z0PPA54MiI+HWzIZlZVVWrfn4N2AlY\nLGm5pG/0OU4z60HVqp8X9SEWM6tpEg/97F3bw+3KGKZYs1bidRiml8xDP80y4oQ2y4gT2iwjTmiz\njDihzTLihDbLiBPaLCNOaLOMOKHNMuKENsvIpB36OUzVMYcpVivk+pq5hzbLSKWqnx3LPiMpJM3s\nT3hmVkbVqp9I2hs4Dnii4ZjMrKJKVT+T8ymqlmR6NmI2fCqdQ0uaD6yNiHt6WNdFAs0GpHRCS9oR\n+CLwF72s7yKBZoNTpYfeH9gPuEfSYxRF9pdJel2TgZlZeaW/h46I+4A9Ru+npJ4XEU83GJeZVVC1\n6qeZTUJVq352Lh9pLBozq0Vb+KcXzT+ZtBF4vGv2TCDXw/Vc9837NXj7RsSEV5UHmtBjBiAtiYh5\nrQbRJ7num/dr8vJYbrOMOKHNMjIZEvrCtgPoo1z3zfs1SbV+Dm1mzZkMPbSZNaTVhJZ0vKSHJa2U\ntKjNWJok6TFJ96V/tbuk7XjqGOv38JJ2lbRY0iPp7y5txljFOPt1jqS16XVbLunENmOsorWEljQF\n+DpwAnAQcJqkg9qKpw/eGhFzh/1rEMb+Pfwi4KaIOAC4Kd0fNpcwxu/8gfPT6zY3Iq4bcEy1tdlD\nHwqsjIhVEfEScAUwv8V4bAzj/B5+PnBpmr4UOHmgQTVgC7/zH2ptJvRsYHXH/TVpXg4C+LGkpZI+\n3HYwfTArItal6aeAWW0G07CzJN2bDsmH7lTCF8X64/CImEtxOvExSUe0HVC/RPE1SS5flVwAvB6Y\nC6wDvtJuOOW1mdBrgb077s9J84ZeRKxNfzcAV1OcXuRkvaQ9AdLfDS3H04iIWB8R/xsRrwDfYghf\ntzYT+i7gAEn7SdoWOBW4psV4GiFpuqSdRqcpCin+v4qpQ+4a4Iw0fQbwwxZjaczoh1TyLobwdWut\n0H5EbJK0ELgBmAJcHBEPtBVPg2YBV0uCon2/FxHXtxtSden38EcBMyWtAc4GzgOuTL+Nfxx4T3sR\nVjPOfh0laS7FKcRjwEdaC7AijxQzy4gvipllxAltlhEntFlGnNBmGXFCm2XECW2WESe0WUac0GYZ\n+T/qikSbqcG2DgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x120fe9a58>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total reward 182\n"
     ]
    }
   ],
   "source": [
    "with tf.Session() as sess:\n",
    "    actor_critic = get_actor_critic(sess, nenvs, nsteps, ob_space, ac_space)\n",
    "    actor_critic.load('/Users/andrewszot/Downloads/weights/model_100000.ckpt')\n",
    "\n",
    "    total_reward = 0\n",
    "    step = 0\n",
    "\n",
    "    while not done:\n",
    "        states = np.expand_dims(states, 0)\n",
    "        actions, values, _ = actor_critic.act(states)\n",
    "\n",
    "        states, reward, done, _ = env.step(actions[0])\n",
    "\n",
    "        total_reward += reward\n",
    "        \n",
    "        displayImage(states, step, total_reward)\n",
    "        step += 1\n",
    "\n",
    "    print('total reward', total_reward)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
